import os
import json
from typing import Dict, List, Optional, Tuple
from itertools import product
import asyncio
from tqdm import tqdm
from datetime import datetime   
from utils import *
from retrieval_base import retrieval_base
from retrieval_benchmark import retrieval_benchmark_base

class Heuristic_RAG_Retrieval(retrieval_base):
    """
    Heuristic RAG retrieval over a repository's README and the links it references.

    On construction the whole pipeline runs once (see ``rag``): for
    ``refine_times`` iterations the current summary is re-summarized, every
    external URL and internal path reported by the LLM is visited and
    summarized, and finally the LLM is asked for the target link.

    The refinement times should ideally be set to ``len(trajectory) - 1``,
    since the README counts as the first link in the trajectory and is
    essentially iteration 0.

    For example, if the trajectory of a repo is
    [https://github.com/ValveSoftware/GameNetworkingSockets,
    https://github.com/ValveSoftware/GameNetworkingSockets/blob/master/BUILDING.md],
    then the refinement times should be 1.

    This class reads ``self.readme_content``, ``self.readme_path``,
    ``self.base_url`` and ``self.repo_name``; they are not defined here and are
    presumably set by ``retrieval_base.__init__`` (verify against the base class).
    """

    def __init__(self, repo_dir, repo_full_name, refine_times=3):
        """
        Parameters:
        - repo_dir: Local directory of the cloned repository (forwarded to the base class).
        - repo_full_name: Repository name in ``owner/repo`` form (forwarded to the base class).
        - refine_times (int): Number of summarize-and-follow-links iterations to run.
        """
        super().__init__(repo_dir, repo_full_name)
        self.refine_times = refine_times
        # rag() (re)initializes trajectories_data itself, so no pre-assignment is needed.
        self.trajectories_data, self.final_target_url = self.rag()

    def predict_target_link(self):
        """Return the target link chosen by the LLM at the end of ``rag``."""
        return self.final_target_url

    def predict_retrieval_trajectory(self):
        """Return every candidate trajectory derived from the recorded visits."""
        return self.get_trajectory_from_trajectory_data()

    def get_trajectory_from_trajectory_data(self):
        """
        Build all candidate trajectories from the per-iteration visit records.

        Every iteration visits all links found in the previous one, so the
        visited links are grouped per iteration and combined by taking one link
        from each iteration, preserving iteration order (a Cartesian product).

        Side effects: sets ``self.all_links`` (flat list of every visited link)
        and ``self.trajectories`` (the returned value).

        Returns:
        - List[List[str]]: One list of links per candidate trajectory. Lists
          (not the tuples produced by ``itertools.product``) are returned so
          that a candidate can compare equal to a list-typed ground truth.
          Empty when no link was visited at all.
        """
        links_per_iteration = [
            [visit['link'] for visit in refinement_data['visits']]
            for refinement_data in self.trajectories_data
        ]
        self.all_links = [link for links in links_per_iteration for link in links]

        # An iteration without any successful visit contributes no hop. Keeping
        # it would make the whole product empty and throw away the links that
        # the other iterations did find, so it is skipped.
        visited_iterations = [links for links in links_per_iteration if links]
        if visited_iterations:
            self.trajectories = [list(path) for path in product(*visited_iterations)]
        else:
            self.trajectories = []
        return self.trajectories

    def _summarize_external_link(self, ext_link):
        """Summarize one external URL; return ``(visit_record, context_text)``."""
        structured_ext_summary = asyncio.run(summarize_link(ext_link))
        ext_build_instructions = structured_ext_summary.Build_Instructions
        context_text = (
            f"External link: {ext_link}\n\n "
            f"Extracted information: {ext_build_instructions}\n\n "
            f"Additional_External_Links: {structured_ext_summary.External_URLs}\n\n"
        )
        visit_record = {
            "link_type": "external",
            "link": ext_link,
            "summary": ext_build_instructions,
            "additional_external_links": structured_ext_summary.External_URLs
        }
        return visit_record, context_text

    def _summarize_internal_link(self, int_link):
        """Summarize one internal file path; return ``(visit_record, context_text)``."""
        structured_int_summary = summarize_text(
            text=read_file(int_link),
            response_format=Extract_Build_Information_and_Links,
            base_url=self.base_url,
        )
        int_build_instructions = structured_int_summary.Build_Instructions
        context_text = (
            f"Internal link: {int_link}\n "
            f"Extracted information: {int_build_instructions}\n\n "
            f"Additional_External_Links: {structured_int_summary.External_URLs}. "
            f"Additional_Internal_Links: {structured_int_summary.Internal_Paths}\n\n"
        )
        visit_record = {
            "link_type": "internal",
            "link": int_link,
            "summary": int_build_instructions,
            "additional_external_links": structured_int_summary.External_URLs,
            "additional_internal_links": structured_int_summary.Internal_Paths
        }
        return visit_record, context_text

    def rag(self):
        """
        Run the iterative summarize-and-follow-links loop and pick the target link.

        Each iteration:
        1. summarizes the current text (the README on the first pass, the
           accumulated summary afterwards) into build instructions plus lists
           of external URLs and internal paths;
        2. summarizes every external URL;
        3. summarizes every internal path;
        and appends what it learned to the running summary, to
        ``self.all_link_data`` and to the iteration's visit record.
        A link whose summarization fails is reported on stdout and skipped.

        Side effects: resets and fills ``self.all_link_data`` and
        ``self.trajectories_data``; sets ``self.final_target_url``.

        Returns:
        - Tuple[List[dict], Any]: ``(self.trajectories_data, self.final_target_url)``,
          where the second item is whatever ``rag_summarize_link`` returns.
        """
        readme_summary = self.readme_content
        self.all_link_data = []  # Example: [{"url": <url>, "summary": <llm summary of the url>}, ...]
        # Clear or (re)init the trajectories on each call
        self.trajectories_data = []
        # refine stays False so the LLM summarizes each text without any prior
        # context, as there is none anyway.
        refine = False

        for i in range(self.refine_times):
            print("Refining the links for the", i+1, "time")

            # Step 1: Summarize the current text to get instructions and links
            structured_readme = summarize_text(
                text=readme_summary,
                refine=refine,
                response_format=Extract_Url_Summary_and_External_Links,
                base_url=self.base_url,
            )

            print("structured_readme_content_and_links", structured_readme)
            readme_summary = structured_readme.Build_Instructions
            # `or []` tolerates a missing (None) link list instead of crashing the loops below.
            external_links = structured_readme.External_URLs or []
            internal_links = structured_readme.Internal_Paths or []

            self.all_link_data.append({"url": self.readme_path, "summary": readme_summary})

            # Record of every successful visit made in this iteration
            iteration_trajectory = {
                "iteration": i + 1,
                "visits": []
            }

            # Step 2: Summarize each external link (best effort: failures are logged and skipped)
            for ext_link in external_links:
                try:
                    visit_record, context_text = self._summarize_external_link(ext_link)
                    readme_summary += context_text
                    self.all_link_data.append({"url": ext_link, "summary": visit_record["summary"]})
                    iteration_trajectory["visits"].append(visit_record)
                except Exception as e:
                    print(f"Error summarizing external link {ext_link}: {e}")

            # Step 3: Summarize each internal link (best effort: failures are logged and skipped)
            for int_link in internal_links:
                try:
                    visit_record, context_text = self._summarize_internal_link(int_link)
                    readme_summary += context_text
                    self.all_link_data.append({"url": int_link, "summary": visit_record["summary"]})
                    iteration_trajectory["visits"].append(visit_record)
                except Exception as e:
                    print(f"Error summarizing internal link {int_link}: {e}")

            # Add the iteration's visited links to the instance-level trajectories
            self.trajectories_data.append(iteration_trajectory)

            print("Summarized content after refining the links for the", i+1, "time")
            print("*"*50)
            print(readme_summary)
            print("*"*50)

        # Final result: let the LLM choose the target link from everything collected
        string_data = json.dumps(self.all_link_data, ensure_ascii=False, indent=2)
        self.final_target_url = rag_summarize_link(string_data, repo_name=self.repo_name)

        return self.trajectories_data, self.final_target_url

class Heuristic_RAG_Retrieval_Benchmark(retrieval_benchmark_base):
    """
    Benchmark driver that runs ``Heuristic_RAG_Retrieval`` on a benchmark item
    and scores its predicted trajectory and target link.

    Dataset access (``get_item``, ``get_ground_truth_trajectory``), scoring
    helpers (``calculate_trajectory_coverage``) and the metric counters updated
    below come from ``retrieval_benchmark_base`` and are not visible here.
    """

    def __init__(self, input_raw_data_path, output_benchmark_path, cloned_repos_dir, output_retrieval_results_file_path, pre_computed_benchmark_file_path=None, pre_computed_retrieval_results_path=None, refine_times = 3, **kwargs):
        """
        Parameters:
        - refine_times (int): Iteration count handed to every ``Heuristic_RAG_Retrieval``.
        - **kwargs: Accepted and ignored.
        - All remaining arguments are forwarded positionally to the base class.
        """
        # Assigned before super().__init__().
        # NOTE(review): presumably because the base constructor may already
        # call generate_single_retrieval_result - confirm.
        self.refine_times = refine_times
        super().__init__(input_raw_data_path, output_benchmark_path, cloned_repos_dir, pre_computed_benchmark_file_path, output_retrieval_results_file_path, pre_computed_retrieval_results_path)

    def evaluate_trajectory(self, index, predicted_trajectories):
        """
        Evaluate the predicted trajectories against the ground truth trajectory for a given repository.

        Heuristic retrieval returns all candidate trajectories, since the LLM
        is assumed to have seen every link in them. The first candidate that
        matches the ground truth exactly wins; otherwise the candidate with the
        highest coverage (first one on ties) is selected.

        Side effects: on an exact match ``self.trajectory_accuracy`` and
        ``self.trajectory_coverage`` are each increased by 1; otherwise
        ``self.trajectory_coverage`` is increased by the best coverage. The
        length of the selected trajectory is added to ``self.trajectory_length``.
        Nothing is updated when ``predicted_trajectories`` is empty.

        Parameters:
        - index (int): The index of the repository in the retrieval benchmark dataset.
        - predicted_trajectories (List[List[str]]): A list of predicted trajectories, where each trajectory is a sequence of URLs.

        Returns:
        - List[str]: The selected trajectory, which is either the exact match with the ground truth trajectory or the one with the highest coverage (``[]`` if there are no candidates).
        """
        ground_truth_trajectory = self.get_ground_truth_trajectory(index)
        # Compare as lists: a tuple (e.g. from itertools.product) never equals a
        # list in Python, even when the elements are identical.
        ground_truth_links = list(ground_truth_trajectory)

        best_coverage = None
        selected_predicted_trajectory = []

        for trajectory in predicted_trajectories:
            if list(trajectory) == ground_truth_links:
                # Exact match: count it and return immediately
                self.trajectory_accuracy += 1
                self.trajectory_coverage += 1
                self.trajectory_length += len(trajectory)
                return trajectory

            # Otherwise, calculate the coverage of the predicted trajectory against ground truth trajectory
            coverage = self.calculate_trajectory_coverage(predicted_trajectory=trajectory, ground_truth_trajectory=ground_truth_trajectory)
            # Strict '>' keeps the first candidate among equally good ones
            if best_coverage is None or coverage > best_coverage:
                best_coverage = coverage
                selected_predicted_trajectory = trajectory

        # No exact match. The coverage is used purely for selection and is not
        # added to the accuracy, since a partial trajectory is considered to
        # provide no useful information for the task.
        if best_coverage is not None:
            self.trajectory_length += len(selected_predicted_trajectory)
            self.trajectory_coverage += best_coverage
        return selected_predicted_trajectory

    def evaluate_target_link(self, index, predicted_target_link):
        """Score the predicted target link using the base-class implementation unchanged."""
        return super().evaluate_target_link(index, predicted_target_link)

    def generate_single_retrieval_result(self, index):
        """
        Run retrieval for one benchmark item, score it, and save the result as JSON.

        The result is written next to ``self.output_retrieval_results_file_path``
        as ``<repo_name>_retrieval_results_<timestamp>.json``.

        Parameters:
        - index (int): The index of the repository in the retrieval benchmark dataset.

        Returns:
        - dict: The retrieval result record, or ``{}`` if anything failed (the
          error is printed, not raised).
        """
        try:
            benchmark_data = self.get_item(index)
            repo_dir = benchmark_data['repo_dir']
            # Strip a trailing '/' so that ".../owner/repo/" still yields "owner/repo"
            url_parts = benchmark_data['repo_url'].rstrip('/').split('/')
            repo_full_name = f"{url_parts[-2]}/{url_parts[-1]}"
            retrieval_class = Heuristic_RAG_Retrieval(repo_dir, repo_full_name, refine_times=self.refine_times)
            target_trajectories = retrieval_class.predict_retrieval_trajectory()
            # First we evaluate the trajectory, and get the selected predicted trajectory for the target link prediction
            selected_predicted_trajectory = self.evaluate_trajectory(index, target_trajectories)
            predicted_target_link = retrieval_class.predict_target_link()
            self.evaluate_target_link(index, predicted_target_link)

            retrieval_results = {
                "repo_name": benchmark_data['repo_name'],
                "repo_dir": benchmark_data['repo_dir'],
                "ground_truth_trajectory": benchmark_data['retrieval_trajectory'],
                "predicted_trajectory": selected_predicted_trajectory,
                "ground_truth_target_link": benchmark_data['retrieval_target_link'],
                "predicted_target_link": predicted_target_link
            }
            individual_retrieval_directory = os.path.dirname(self.output_retrieval_results_file_path)
            # Create the output directory up front so a finished (and already
            # scored) retrieval is not lost to a missing folder. dirname() is ''
            # for a bare file name, which makedirs would reject.
            if individual_retrieval_directory:
                os.makedirs(individual_retrieval_directory, exist_ok=True)
            timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            individual_retrieval_file_path = os.path.join(individual_retrieval_directory, f"{benchmark_data['repo_name']}_retrieval_results_{timestamp}.json")
            with open(individual_retrieval_file_path, 'w', encoding='utf-8') as f:
                json.dump(retrieval_results, f, indent=4)
            print(f"Individual retrieval results saved to {individual_retrieval_file_path}")
            return retrieval_results

        except Exception as e:
            # Per-item boundary: report the failure and return an empty result
            print(f"Error generating retrieval result for index {index}: {e}")
            return {}
        


    
## For Testing
if __name__ == '__main__':
    # --- Disabled: ad-hoc check of a single repository ---
    # repo_dir = '/mnt/midnight/steven_zhang/LLM_assisted_compilation/cloned_repos/catboost'
    # data = Heuristic_RAG_Retrieval(repo_dir, 'catboost/catboost', 3)
    # print(data.predict_retrieval_trajectory())
    # print(data.predict_target_link())

    # --- Disabled: merge the per-repo result files into a single rag_output.json ---
    # pre_computed_retrieval_results_dir = '/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/experiment_results/retrieval_results/heuristic_rag'
    #
    # retrieval_data = []
    # for file in os.listdir(pre_computed_retrieval_results_dir):
    #     if file.endswith('.json'):
    #         with open(os.path.join(pre_computed_retrieval_results_dir, file), 'r') as f:
    #             data = json.load(f)
    #             retrieval_data.append(data)
    #
    # with open("rag_output.json", 'w') as f:
    #     json.dump(retrieval_data, f, indent=4)

    # Instantiate the benchmark with the 76-repo list and the pre-computed
    # benchmark / retrieval result files.
    # NOTE(review): presumably the base-class constructor drives the actual run - confirm.
    benchmark = Heuristic_RAG_Retrieval_Benchmark(
        input_raw_data_path='/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/data/repo_list_76.json',
        output_benchmark_path="/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/data/retrieval_benchmark_76.json",
        cloned_repos_dir='/mnt/midnight/steven_zhang/LLM_assisted_compilation/cloned_repos',
        output_retrieval_results_file_path='/mnt/midnight/steven_zhang/LLM_assisted_compilation/rag_out_results.json',
        pre_computed_benchmark_file_path='/mnt/midnight/steven_zhang/LLM_assisted_compilation/Compilation_Benchmark/data/retrieval_benchmark_76.json',
        pre_computed_retrieval_results_path='/mnt/midnight/steven_zhang/LLM_assisted_compilation/rag_output.json',
        refine_times=3,
    )